"""
Functions for computing fundamental and adjoint string tensions
given real, per-L kernel and flip-count arrays.

Compat notes:
- `compute_string_tension(...)` keeps the exact signature and returns a scalar (no change).
- New `compute_tension_with_err(...)` returns (sigma, sigma_err, meta) using
  a reproducible block bootstrap over the KERNEL entries.
"""

from __future__ import annotations
import os
import hashlib
from dataclasses import dataclass
from typing import Sequence, Callable, Dict, Any, Tuple, Optional

import numpy as np


# ---------------------------
# Core physics helpers
# ---------------------------

def fundamental_string_tension(kernel: np.ndarray, b: float, k_exp: float) -> float:
    """
    Return ``b * mean(|kernel| ** k_exp)``, the fundamental string tension.

    Magnitudes are used so that a fractional ``k_exp`` applied to negative
    entries cannot produce NaN. When any entry is negative, a one-line
    ``[WARN]`` notice is printed to stdout before ``abs()`` is applied.

    Parameters
    ----------
    kernel : numpy.ndarray
        Kernel values of any shape (anything ``np.asarray`` accepts); they are
        cast to float and flattened before use.
    b : float
        Coupling constant multiplying the average.
    k_exp : float
        Exponent applied element-wise to ``|kernel|`` before averaging.

    Returns
    -------
    float
        The fundamental string tension estimate. An empty kernel gives NaN
        (``np.mean`` of an empty array).
    """
    values = np.asarray(kernel, dtype=float).ravel()

    # The negative-value check only makes sense when there is data to inspect.
    if values.size > 0:
        lo = values.min()
        hi = values.max()
        if lo < 0:
            print(f"[WARN] kernel contains negative values (min={lo}, max={hi}); using abs()")

    magnitudes = np.abs(values)
    mean_power = float(np.mean(magnitudes ** k_exp))
    return b * mean_power


def adjoint_string_tension(sigma_fund: float, N: int) -> float:
    """
    Scale a fundamental string tension to the adjoint one via Casimir scaling.

    For SU(N) the scaling factor is ``2 * N**2 / (N**2 - 1)``.

    Parameters
    ----------
    sigma_fund : float
        Fundamental string tension.
    N : int
        Number of colours in SU(N). ``N == 1`` makes the denominator zero and
        raises ``ZeroDivisionError``.

    Returns
    -------
    float
        ``(2*N^2)/(N^2 - 1) * sigma_fund``.
    """
    n_squared = N ** 2
    casimir_ratio = (2 * n_squared) / (n_squared - 1)
    return casimir_ratio * sigma_fund


# ---------------------------
# Compat scalar API (unchanged)
# ---------------------------

def compute_string_tension(
    *,
    b: float,
    k_exp: float,
    n0: float,
    L: int,
    gauge: str,
    volumes: Sequence[int],
    fit_range: Sequence[int],
    kernel_path: str,
    flip_counts_path: str,
) -> float:
    """
    Return the adjoint string tension for one lattice size (scalar compat API).

    The kernel is loaded from ``kernel_path`` and reduced to a fundamental
    tension with ``fundamental_string_tension``. That value is then scaled to
    the adjoint representation by Casimir scaling for SU(2) or SU(3). The
    flip-count file is resolved (the given path first, then fallback
    locations) and loaded only to verify it is readable; its contents are not
    used by this estimator. Signature and behavior are kept stable so other
    sims are unaffected.

    Parameters
    ----------
    b : float
        Coupling constant.
    k_exp : float
        Exponent for the fundamental tension.
    n0 : float
        Accepted for API consistency; not used.
    L : int
        Lattice size. A kernel whose size differs from ``2*L^2`` only triggers
        a printed warning.
    gauge : str
        "SU2" or "SU3" (case-insensitive).
    volumes : Sequence[int]
        Not used by this simplified model.
    fit_range : Sequence[int]
        Not used by this simplified model.
    kernel_path : str
        Path to the .npy kernel (any shape; flattened internally).
    flip_counts_path : str
        Preferred path to the .npy flip-count data.

    Returns
    -------
    float
        The adjoint string tension.

    Raises
    ------
    RuntimeError
        If the kernel or flip-count file cannot be loaded.
    ValueError
        If ``gauge`` is neither SU2 nor SU3.
    """
    try:
        kernel = np.load(kernel_path)
    except Exception as e:
        raise RuntimeError(f"Failed to load kernel from '{kernel_path}': {e}")

    # Existence/readability check only; the loaded array is discarded.
    flip_path = _resolve_flip_counts(L, flip_counts_path)
    try:
        np.load(flip_path)
    except Exception as e:
        raise RuntimeError(f"Failed to load flip-counts from '{flip_path}': {e}")

    # Soft sanity check on the expected kernel length; never fatal.
    n_entries = np.size(kernel)
    want = 2 * (L ** 2)
    if n_entries != want:
        print(f"[WARN] kernel size {n_entries} != 2*L^2 ({want}) for L={L}; proceeding with flatten().")

    sigma_fund = fundamental_string_tension(kernel, b, k_exp)

    colours = {"SU2": 2, "SU3": 3}.get(str(gauge).upper())
    if colours is None:
        raise ValueError(f"Unsupported gauge channel '{gauge}'; expected 'SU2' or 'SU3'.")
    return adjoint_string_tension(sigma_fund, N=colours)


# ---------------------------
# New: reproducible bootstrap API
# ---------------------------

@dataclass
class BootSpec:
    """
    Settings for one block-bootstrap run.

    Attributes
    ----------
    reps : int
        Number of bootstrap replicates to draw.
    block : int
        Block length in samples; ``0`` requests automatic selection.
    seed : int
        Seed for the random generator, so resampling is reproducible.
    """

    reps: int = 400
    block: int = 0
    seed: int = 1337


def _auto_block(n: int) -> int:
    """
    Pick a default bootstrap block length for a series of length ``n``.

    Short series get a few large blocks (``n // 4``, at least 2). Medium ones
    (up to 256) use ``n // 8``, at least 4. Longer ones are cut into about 32
    blocks (``n // 32``, at least 8). Non-positive ``n`` returns 1.
    """
    if n <= 0:
        return 1
    # (inclusive upper size bound, divisor, minimum block length), checked in order.
    for limit, divisor, floor in ((64, 4, 2), (256, 8, 4)):
        if n <= limit:
            return max(floor, n // divisor)
    # For positive n, n // 32 never exceeds n // 3, so no extra cap is needed.
    return max(8, n // 32)


def _block_indices(n: int, blk: int, rng: np.random.Generator) -> np.ndarray:
    """
    Draw one moving-block bootstrap resample of the indices ``0..n-1``.

    Block start positions are drawn uniformly from every offset at which a
    whole block fits. ``ceil(n / blk)`` contiguous blocks are concatenated
    and the result is truncated, so the resample always has exactly ``n``
    indices, matching the size of the original sample.

    Parameters
    ----------
    n : int
        Length of the series being resampled. ``n <= 0`` yields an empty array.
    blk : int
        Block length. Values larger than ``n`` are clamped to ``n``.
    rng : numpy.random.Generator
        Source of randomness; one ``integers`` draw is made per call.

    Returns
    -------
    numpy.ndarray
        int64 array of length ``max(n, 0)`` with every entry in ``[0, n)``.

    Raises
    ------
    ValueError
        If ``blk < 1``. A non-positive block length would produce empty
        blocks and hence an empty (NaN-producing) resample.
    """
    if blk < 1:
        raise ValueError(f"Block length must be >= 1, got {blk}.")
    if n <= 0:
        return np.empty(0, dtype=np.int64)
    blk = min(int(blk), int(n))
    # Ceil division: enough blocks to cover all n positions. Floor division
    # would leave the resample up to blk-1 entries short, which biases the
    # bootstrap spread upward.
    nb = -(-n // blk)
    starts = rng.integers(0, n - blk + 1, size=nb, dtype=np.int64)
    idx = (starts[:, None] + np.arange(blk, dtype=np.int64)).ravel()
    return idx[:n]


def _sha256(path: str) -> str:
    """Return the hex SHA-256 digest of the file at ``path``, read in 1 MiB chunks."""
    digest = hashlib.sha256()
    with open(path, "rb") as stream:
        while True:
            piece = stream.read(1 << 20)
            if not piece:
                break
            digest.update(piece)
    return digest.hexdigest()


def _resolve_flip_counts(L: int, explicit_path: Optional[str] = None) -> str:
    """
    Locate the flip-count file for lattice size ``L``.

    Search order: ``explicit_path`` (when non-empty), then
      1) data/results/vol4_loop_fluctuation_sim/L{L}/flip_counts_L{L}.npy
      2) data/flip_counts/flip_counts_L{L}.npy
    The first path that exists on disk is returned. If none exists, the
    explicit path is returned anyway (or fallback 2 when no explicit path was
    given), so the caller's subsequent load reports that path in its error.
    """
    fallbacks = (
        f"data/results/vol4_loop_fluctuation_sim/L{L}/flip_counts_L{L}.npy",
        f"data/flip_counts/flip_counts_L{L}.npy",
    )
    search_order = ((explicit_path,) if explicit_path else ()) + fallbacks
    found = next((p for p in search_order if os.path.exists(p)), None)
    if found is not None:
        return found
    # Nothing exists: hand back the caller's path so legacy callers see the same error site.
    return explicit_path if explicit_path else fallbacks[-1]


def compute_tension_with_err(
    *,
    b: float,
    k_exp: float,
    n0: float,
    L: int,
    gauge: str,
    volumes: Sequence[int],
    fit_range: Sequence[int],
    kernel_path: str,
    flip_counts_path: Optional[str] = None,
    # Bootstrap controls (optional; safe defaults)
    bootstrap_reps: int = 400,
    bootstrap_block: int = 0,
    bootstrap_seed: int = 1337,
    # Provenance/guard
    kernel_hash_guard: bool = True,
) -> Tuple[float, float, Dict[str, Any]]:
    """
    Compute adjoint string tension and an uncertainty via block bootstrap
    over the KERNEL entries (the actual averaged quantity).

    The central value is computed exactly as in ``compute_string_tension``:
    ``b * mean(|kernel|**k_exp)``, then Casimir scaling to the adjoint. The
    uncertainty is the sample standard deviation (``ddof=1``) of that same
    estimator over ``bootstrap_reps`` block-bootstrap resamples of the
    flattened kernel. The resamples are drawn from a seeded generator, so the
    result is reproducible.

    Parameters
    ----------
    b, k_exp, L, gauge, kernel_path
        As in ``compute_string_tension``.
    n0, volumes, fit_range
        Accepted for API consistency. ``n0`` is only recorded in ``meta``;
        ``volumes`` and ``fit_range`` are unused.
    flip_counts_path : str or None
        Preferred flip-count file. When it is None or missing, fallback
        locations are tried. The file is loaded only to verify it is readable.
    bootstrap_reps : int
        Number of bootstrap replicates; must be >= 2 (``ddof=1`` needs two).
    bootstrap_block : int
        Block length; ``0`` selects one automatically; must be >= 0.
    bootstrap_seed : int
        Seed for ``numpy.random.default_rng``.
    kernel_hash_guard : bool
        When True, record the SHA-256 of the kernel file in ``meta``.

    Returns
    -------
    (sigma, sigma_err, meta)
      sigma:      adjoint string tension (same scalar as compute_string_tension)
      sigma_err:  bootstrap standard error (same units as sigma)
      meta:       dict with provenance (strategy, reps, block, n, kernel hash, resolved paths)

    Raises
    ------
    ValueError
        Unsupported ``gauge``, invalid bootstrap settings, or an empty kernel.
    RuntimeError
        If the kernel or flip-count file cannot be loaded.
    """
    # Validate cheap arguments up front so bad input fails before any file I/O.
    gauge_upper = str(gauge).upper()
    if gauge_upper not in ("SU2", "SU3"):
        raise ValueError(f"Unsupported gauge channel '{gauge}'; expected 'SU2' or 'SU3'.")
    N = int(gauge_upper[-1])
    reps = int(bootstrap_reps)
    if reps < 2:
        raise ValueError(f"bootstrap_reps must be >= 2 to estimate a spread, got {reps}.")
    if int(bootstrap_block) < 0:
        raise ValueError(f"bootstrap_block must be >= 0 (0 = auto), got {bootstrap_block}.")

    # Load kernel and flatten
    try:
        kernel = np.load(kernel_path)
    except Exception as e:
        raise RuntimeError(f"Failed to load kernel from '{kernel_path}': {e}") from e
    kvec = np.asarray(kernel, dtype=float).ravel()
    n = kvec.size
    if n == 0:
        raise ValueError(f"Kernel at '{kernel_path}' is empty.")

    # Resolve & verify flip-counts (provided path first, then fallbacks)
    fc_resolved = _resolve_flip_counts(L, flip_counts_path)
    try:
        np.load(fc_resolved)
    except Exception as e:
        raise RuntimeError(f"Failed to load flip-counts from '{fc_resolved}': {e}") from e

    # Optional kernel hash for provenance
    khash = _sha256(kernel_path) if kernel_hash_guard and os.path.exists(kernel_path) else None

    # Central estimate on the full kernel, through the same helpers as the
    # compat path. Any negative-value warning is printed here, once.
    sigma = adjoint_string_tension(fundamental_string_tension(kvec, b, k_exp), N=N)

    # Compute |kernel|**k_exp once. Indexing it per replicate yields exactly the
    # element values fundamental_string_tension(kvec[idx], ...) would compute,
    # without re-running its negative-value check, which would otherwise print
    # one warning per replicate.
    powered = np.abs(kvec) ** k_exp

    spec = BootSpec(reps=reps, block=int(bootstrap_block or _auto_block(n)), seed=int(bootstrap_seed))
    rng = np.random.default_rng(spec.seed)

    samples = np.empty(spec.reps, dtype=float)
    for i in range(spec.reps):
        idx = _block_indices(n, spec.block, rng)
        samples[i] = adjoint_string_tension(b * float(np.mean(powered[idx])), N=N)

    sigma_err = float(samples.std(ddof=1))

    meta: Dict[str, Any] = {
        "strategy": "kernel_block_bootstrap",
        "reps": spec.reps,
        "block": spec.block,
        "n": int(n),
        "kernel_sha256": khash,
        "kernel_path": kernel_path,
        "flip_counts_path": fc_resolved,
        "L": int(L),
        "b": float(b),
        "k": float(k_exp),
        "n0": float(n0),
        "gauge": str(gauge),
    }
    return float(sigma), sigma_err, meta
